import re
from collections import namedtuple
from argparse import Namespace
from typing import NamedTuple
from functools import wraps, partial

import click
import h5py
import numpy as np
import torch

from .bayesopt.util import BACKENDS
from .utils import DebugSet
from .bayesopt.gp import KERNELS, INDUCERS
from .bayesopt.bo import SMOOptimizer, NFTOptimizer, EMICOREOptimizer, SUBSCOREOptimizer
from .bayesopt.bo import SGDOptimizer, SGDGPOptimizer, GRADCOREOptimizer
from .bayesopt.bo import ExpectedImprovement


class DataLog(dict):
    '''Dict that appends keys to lists.

    ``log[key] = value`` appends ``value`` to the list stored under ``key`` instead of replacing
    it; ``None`` values are silently dropped. :meth:`flush` appends the collected lists to
    resizable datasets of an HDF5 file.
    '''
    def __init__(self, data=None, fname=None, prefix=''):
        # dict.__init__ bypasses the overridden __setitem__, so ``data`` is stored verbatim; its
        # values must support ``append`` since later assignments append to them.
        super().__init__({} if data is None else data)
        self.fname = fname
        self.prefix = prefix

    def __setitem__(self, key, value):
        '''Appends value to list under keys; ``None`` is ignored.'''
        if value is None:
            return
        self.init(key)
        self[key].append(value)

    def update(self, obj):
        '''Append every value of the mapping ``obj`` to the list under its key.'''
        for key, val in dict(obj).items():
            self[key] = val

    def init(self, key):
        '''Initialize key to an empty list if it is not present yet.'''
        if key not in self:
            super().__setitem__(key, [])

    def extend(self, key, values):
        '''Extends list under keys by values, skipping ``None`` entries.

        The key is initialized even if ``values`` is ``None`` or empty. ``values`` may be any
        iterable, including numpy arrays and torch tensors.
        '''
        self.init(key)
        # Compare against None rather than testing truthiness: `if values:` raises for
        # multi-element arrays/tensors and silently skips e.g. ``np.array([0.])``.
        if values is not None:
            self[key].extend([value for value in values if value is not None])

    def update_extend(self, obj):
        '''Updates dict by extending all by all lists in obj.'''
        for key, val in dict(obj).items():
            self.extend(key, val)

    def numpy(self, key):
        '''Return a numpy representation of the log at key.

        Tensors are converted with ``.numpy()``; if the last entry is an array the entries are
        stacked along a new axis 0, otherwise ``np.array`` is applied. An empty log yields an
        empty array.
        '''
        value = self[key]
        if not value:
            return np.array([])
        value = [val.numpy() if isinstance(val, torch.Tensor) else val for val in value]
        if isinstance(value[-1], np.ndarray):
            value = np.stack(value, axis=0)
        else:
            value = np.array(value)
        return value

    def flush(self, fname=None, prefix=None, clear=True, overwrite=False):
        '''Write log to hdf5 file and clear the log.

        Each key is written to the dataset ``f'{prefix}/{key}'``. Missing datasets are created
        chunked, gzip-compressed and resizable along axis 0; existing datasets are grown along
        axis 0 and the new rows are appended.

        Args:
            fname: HDF5 file name; defaults to ``self.fname``.
            prefix: prefix of the dataset names; defaults to ``self.prefix`` (or ``''``).
            clear: empty all lists after writing (the keys themselves are kept).
            overwrite: open the file with mode ``'w'`` (truncate) instead of ``'a'`` (append).

        Raises:
            RuntimeError: if neither ``fname`` nor ``self.fname`` is set.
        '''
        if fname is None:
            fname = self.fname
        if fname is None:
            raise RuntimeError('No filename supplied, cannot flush!')
        if prefix is None:
            prefix = self.prefix
        if prefix is None:
            prefix = ''

        with h5py.File(fname, 'w' if overwrite else 'a') as fd:
            for base_key in self:
                key = f'{prefix}/{base_key}'
                data = self.numpy(base_key)
                if key not in fd:
                    subshp = tuple(data.shape[1:])
                    fd.create_dataset(
                        key,
                        shape=data.shape,
                        dtype=data.dtype,
                        # unlimited along axis 0 so later flushes can append rows
                        maxshape=(None,) + subshp,
                        data=data,
                        chunks=True,
                        compression='gzip'
                    )
                else:
                    size = fd[key].shape[0]
                    fd[key].resize(size + data.shape[0], axis=0)
                    fd[key][size:] = data

        if clear:
            for key in self:
                self[key].clear()


class FinalProperties:
    '''Base class that expands ``final_properties``-decorated methods into properties.

    For every non-dunder, callable class attribute that carries a ``_final_properties`` tuple,
    one property per listed name is installed on the subclass. Reading such a property calls the
    method as ``method(self, key=<name>)``. Instances start with an empty ``self._dict``.
    '''
    def __init_subclass__(cls, *args, **kwargs):
        pending = []
        # gather first: the class namespace must not change while it is being iterated
        for attr, member in cls.__dict__.items():
            is_candidate = (
                not attr.startswith('__')
                and callable(member)
                and hasattr(member, '_final_properties')
            )
            if not is_candidate:
                continue
            for prop_name in member._final_properties:
                pending.append((prop_name, property(partial(member, key=prop_name))))

        for prop_name, prop in pending:
            setattr(cls, prop_name, prop)

    def __init__(self):
        # per-instance store for the computed property values
        self._dict = {}


def final_property(func):
    '''Turn ``func(self)`` into a read-only property whose value is computed only once.

    The result is stored in ``self._dict`` under ``func.__name__`` (the owner must provide that
    mapping); subsequent reads return the stored value without calling ``func`` again.
    '''
    cache_key = func.__name__

    def getter(self):
        if cache_key in self._dict:
            return self._dict[cache_key]
        value = func(self)
        self._dict[cache_key] = value
        return value

    return property(wraps(func)(getter))


def final_properties(*names):
    '''Decorator for a method that computes several values at once, one per entry of ``names``.

    The decorated method takes ``(self, key)``. If ``key`` is not yet in ``self._dict``, the
    original method is called once and its results are stored under ``names`` (paired in order).
    The value for ``key`` is then returned from ``self._dict``. ``names`` is recorded in the
    ``_final_properties`` attribute of the returned function.
    '''
    def decorator(func):
        @wraps(func)
        def lookup(self, key):
            if key not in self._dict:
                results = func(self)
                # zip stops at the shorter side: names without a result are never stored, so
                # looking them up raises KeyError below
                self._dict.update(zip(names, results))
            return self._dict[key]

        lookup._final_properties = names
        return lookup
    return decorator


def csobj(dtype, sep=',', maxsplit=-1, length=-1, container=list):
    '''Create a converter parsing ``sep``-separated strings into a container of ``dtype`` values.

    Empty fields are skipped. Values that are already converted are returned unchanged. This
    covers tuples, and instances of ``container`` when it is a class. It keeps the converter
    idempotent, since click documents that a type's ``convert`` may receive such values.

    Args:
        dtype: callable converting a single field.
        sep: field separator, forwarded to :meth:`str.split`.
        maxsplit: forwarded to :meth:`str.split`.
        length: required number of fields, or ``<= 0`` to accept any number.
        container: callable building the result from the converted fields.

    Returns:
        The converter function.

    Raises (from the converter):
        RuntimeError: if ``length > 0`` and the number of fields differs.
    '''
    passthrough = (tuple, container) if isinstance(container, type) else (tuple,)

    def wrapped(string):
        if not isinstance(string, str) and isinstance(string, passthrough):
            return string
        result = container(dtype(elem) for elem in string.split(sep, maxsplit) if elem)
        if length > 0 and len(result) != length:
            raise RuntimeError(f'Invalid number of fields. Provided {len(result)} but expected {length}!')
        return result
    return wrapped


def option_dict(string):
    '''Parse ``'a=1,b=2'`` into ``{'a': '1', 'b': '2'}``; dicts are returned as-is.

    Items are separated by ``,``; empty items are skipped; only the first ``=`` of an item
    separates key from value, and values are kept as strings.
    '''
    if not isinstance(string, dict):
        string = dict(
            entry.split('=', 1)
            for entry in string.split(',')
            if entry
        )
    return string


def _append_param(func, param):
    '''Attach the click parameter ``param`` to ``func``.

    A finished ``click.Command`` receives it directly in ``params``. A plain function collects it
    in its ``__click_params__`` list, which click's command decorator reads later. This is the
    same mechanism click's own option decorators use.
    '''
    if isinstance(func, click.Command):
        target = func.params
    else:
        if not hasattr(func, '__click_params__'):
            func.__click_params__ = []
        target = func.__click_params__
    target.append(param)


def namedtuple_as_dict(input):
    '''Recursively convert a (nested) configuration object into plain Python data.

    - namedtuples and ``argparse.Namespace`` objects become dicts,
    - dicts are converted value by value, tuples/lists/sets become lists,
    - objects with a ``_dictsource`` attribute are replaced by that attribute,
    - everything else is returned unchanged.
    '''
    if isinstance(input, tuple) and hasattr(input, '_asdict'):
        return namedtuple_as_dict(input._asdict())
    if isinstance(input, Namespace):
        # vars() of the instance (its attributes), not of the Namespace class
        return namedtuple_as_dict(vars(input))
    if isinstance(input, dict):
        return {key: namedtuple_as_dict(value) for key, value in input.items()}
    if isinstance(input, (tuple, list, set)):
        return [namedtuple_as_dict(value) for value in input]
    if hasattr(input, '_dictsource'):
        return input._dictsource
    return input


class Data(NamedTuple):
    '''Container of observations: inputs ``x``, targets ``y`` and an optional ``readout``.'''
    # NOTE(review): presumably x holds circuit parameters, y the corresponding measured values
    # and readout the number of shots used per observation -- confirm against the producers.
    x: torch.Tensor
    y: torch.Tensor
    readout: torch.Tensor = None  # optional, defaults to None


class TrueSolution(NamedTuple):
    '''Reference solution with the fields ``e0``, ``e1`` and ``wf``.'''
    # NOTE(review): the names suggest ground-state energy (e0), first excited energy (e1) and
    # ground-state wave function (wf) -- confirm where this tuple is constructed.
    e0: torch.Tensor
    e1: torch.Tensor
    wf: torch.Tensor


class PositiveFloatParam(click.ParamType):
    '''Click parameter type accepting only strictly positive floats.

    Invalid input (not parseable as a float, non-positive, or NaN) is reported through
    ``self.fail``, i.e. as a click usage error instead of a traceback.
    '''
    name = 'positive_float'

    def convert(self, value, param, ctx):
        if not isinstance(value, float):
            try:
                value = float(value)
            except (TypeError, ValueError):
                self.fail(f'Value {value!r} is not a valid float!', param, ctx)
        # `not value > 0.` also rejects NaN, for which `value <= 0.` would be False
        if not value > 0.:
            self.fail(f'Value {value} is non-positive!', param, ctx)
        return value

    def __repr__(self):
        return 'positive_float'


PositiveFloat = PositiveFloatParam()


class ChoiceOrFloat(click.Choice):
    '''``click.Choice`` that additionally accepts any value parseable as a float.

    The choices are tried first. If they do not match, the value is parsed with ``click.FLOAT``.
    If that also fails, the choice error is re-raised so the message lists the valid choices.
    '''
    name = 'choice_or_float'

    def convert(self, value, param, ctx):
        try:
            converted = super().convert(value, param, ctx)
        except click.BadParameter as choice_error:
            try:
                converted = click.FLOAT(value)
            except click.BadParameter:
                raise choice_error
        return converted


class AnnotatedFuncFromDict(click.ParamType):
    '''Click type resolving ``name:key=value:key=value...`` into a configured callable.

    ``name`` selects an entry of ``fndict``. Each ``key=value`` is converted with the annotation
    of the matching parameter (of ``__init__`` for classes) and passed as a keyword argument.
    The function is then called with these arguments, or wrapped in ``functools.partial`` when
    ``partial=True``. Where attribute assignment is possible, the result gets a ``_dictsource``
    dict recording the name and the converted arguments. Entries of ``fndict`` that are ``None``
    convert to ``None``.
    '''
    name = 'AnnotatedFuncFromDict'

    def __init__(self, fndict, partial=False):
        super().__init__()
        self.fndict = fndict
        self.partial = partial

    def convert(self, value, param, ctx):
        # already converted, or a callable supplied programmatically
        if callable(value):
            return value

        name, *kwargstr = value.split(':')
        if name not in self.fndict:
            available = '\', \''.join(self.fndict)
            raise click.BadParameter(
                f"No such function: '{name}'. Available functions are '{available}'")
        func = self.fndict[name]

        if func is None:
            return None

        if isinstance(func, type):
            annotations = func.__init__.__annotations__
        else:
            annotations = func.__annotations__
        # 'return' annotates the result, it is not an argument that can be passed
        annotations = {key: dtype for key, dtype in annotations.items() if key != 'return'}

        kwargtups = {}
        for obj in kwargstr:
            if not obj:
                continue
            key, sep, val = obj.partition('=')
            if not sep:
                raise click.BadParameter(
                    f"Invalid argument '{obj}' for function '{name}', expected 'key=value'."
                )
            kwargtups[key] = val

        missing = set(kwargtups).difference(annotations)
        if missing:
            invalid = '\', \''.join(missing)
            available = '\', \''.join(annotations)
            raise click.BadParameter(
                f"No such arguments for function '{name}': '{invalid}'. "
                f"Valid arguments are: '{available}'"
            )

        kwargs = {}
        for key, val in kwargtups.items():
            try:
                kwargs[key] = annotations[key](val)
            except ValueError as error:
                raise click.BadParameter(
                    f"Invalid value '{val}' for argument '{key}' of function '{name}': {error}"
                ) from error

        if self.partial:
            retval = partial(func, **kwargs)
        else:
            retval = func(**kwargs)

        # best effort: some objects (e.g. builtins) do not accept new attributes
        try:
            retval._dictsource = {'funcname': name, **kwargs}
        except AttributeError:
            pass

        return retval


class _ClassProperty:
    '''Read-only property computed from the class, also when read through an instance.

    Stacking ``@classmethod`` on ``@property`` is deprecated since Python 3.11 and no longer
    works in Python 3.13; this descriptor provides the same ``Cls.attr`` access on every version.
    '''

    def __init__(self, fget):
        self.fget = fget
        self.__doc__ = getattr(fget, '__doc__', None)

    def __get__(self, obj, owner=None):
        if owner is None:
            owner = type(obj)
        return self.fget(owner)


class OptionParams(click.ParamType):
    '''Click parameter type parsing ``key=value`` option strings into a namedtuple.

    Subclasses declare their options as annotated class attributes::

        class MyParams(OptionParams):
            steps: int = 10, 'Number of steps'

    The annotation converts the value. The assigned value is either a plain default or a
    ``(default, help)`` tuple. A subclass instance can be used as a click type, parsing strings
    such as ``'steps=3,lr=0.1'``; newlines may separate fields and ``#`` starts a comment.
    Alternatively, :meth:`options` expands the fields into individual command line options.
    '''
    name = 'OptionParams'
    # One match per token: a `key=value` pair optionally terminated by ',' or a newline, a '#'
    # comment up to the end of the line, whitespace, or anything else (a parse error). Values
    # therefore cannot contain ',', '#' or newlines.
    _rexp = re.compile(
        r'(?P<key>[^\s#=]+)=(?P<value>[^#,\n]+)[,\n]?|'
        r'(?P<comment>#[^\n]*$)|'
        r'(?P<whitespace>\s+)|'
        r'(?P<error>.+)',
        re.MULTILINE
    )

    @classmethod
    def __init_subclass__(cls, *args, **kwargs):
        '''Collect the annotated option fields of a subclass.

        Sets ``_types`` (field -> converter), ``_defaults`` (field -> raw default), ``_help``
        (field -> help text) and ``_return_type`` (namedtuple with one field per option).
        '''
        super().__init_subclass__(*args, **kwargs)

        # attributes annotated on click.ParamType itself (e.g. `name`) are not options
        base_annotations = getattr(click.ParamType, '__annotations__', {})
        cls._types = {
            key: value
            for key, value in cls.__annotations__.items()
            if key not in base_annotations
        }
        cls._defaults = {}
        cls._help = {}
        for key in cls._types:
            obj = getattr(cls, key, None)
            # `field: type = default, 'help'` assigns the tuple (default, 'help')
            if isinstance(obj, tuple) and len(obj) == 2 and isinstance(obj[1], str):
                cls._defaults[key] = obj[0]
                cls._help[key] = obj[1]
            else:
                cls._defaults[key] = obj

        # built from the filtered fields so the tuple fields always match `_types`/`_defaults`
        cls._return_type = namedtuple(
            cls.__name__,
            list(cls._types),
            defaults=[cls._defaults[key] for key in cls._types]
        )

    @classmethod
    def _parse(cls, string):
        '''Yield the raw ``(key, value)`` string pairs of ``string``.

        Raises:
            RuntimeError: on text that is neither a pair, a comment, nor whitespace.
        '''
        for mt in cls._rexp.finditer(string):
            if mt['error'] is not None:
                raise RuntimeError(f'Parsing options failed at \'{mt["error"]}\'!')
            if mt.lastgroup in ('key', 'value'):
                yield (mt['key'], mt['value'])

    @classmethod
    def _members(cls):
        '''Yield ``(field, converter, help, raw default)`` for every option field.'''
        for key, dtype in cls._types.items():
            yield (key, dtype, cls._help.get(key, ''), cls._defaults.get(key, None))

    @classmethod
    def options(cls, prefix='', namespace=None, help=None):
        '''Return a decorator adding one click option ``--[prefix-]field`` per field.

        The values are not passed to the decorated function as arguments (``expose_value`` is
        False). The callback stores each one in ``ctx.params`` under the click parameter name
        (which includes ``prefix``). If ``namespace`` is given, the values are set as attributes
        of an ``argparse.Namespace`` stored at ``ctx.params[namespace]`` instead. ``help`` is
        currently unused.
        '''
        prefix = prefix.replace('_', '-')

        def decorator(func):
            for key, dtype, dhelp, default in cls._members():
                key = key.replace('_', '-')
                param = click.Option(
                    [f'--{prefix}-{key}' if prefix else f'--{key}'],
                    type=dtype,
                    help=dhelp,
                    default=default,
                    callback=cls._callback(namespace),
                    expose_value=False
                )
                _append_param(func, param)
            return func
        return decorator

    @classmethod
    def _callback(cls, namespace=None):
        '''Create the click callback that stores option values (see :meth:`options`).'''
        def callback(ctx, param, value):
            key = param.name.replace('-', '_')
            if namespace:
                setattr(ctx.params.setdefault(namespace, Namespace()), key, value)
            else:
                ctx.params[key] = value
            return value
        return callback

    def convert(self, value, param, ctx):
        '''Convert an option string (or an already converted tuple) into ``_return_type``.

        Fields missing from the string keep their raw (unconverted) defaults.

        Raises:
            RuntimeError: on parse errors or unknown option names.
        '''
        retval = {}
        if isinstance(value, str):
            for key, val in self._parse(value):
                key = key.replace('-', '_')
                if key not in self._types:
                    raise RuntimeError(f'No such option for {param.name}: {key}!')
                retval[key] = self._types[key](val)
        if isinstance(value, self._return_type):
            return value

        return self._return_type(**retval)

    def __repr__(self):
        members = {
            key.replace('_', '-'): (getattr(dtype, '__name__', str(dtype)), self._defaults.get(key, None))
            for key, dtype in self._types.items()
        }
        msg = ', '.join(sorted([
            f'{key}:{dtype}={default}' for key, (dtype, default) in members.items()
        ]))

        return f'{type(self).__name__}({msg})'

    def get_metavar(self, param, ctx=None):
        '''Return a multi-line metavar listing every field with default, type and help.

        ``ctx`` is accepted because newer click versions (8.2+) pass it; it is not used.
        '''
        members = {
            key.replace('_', '-'): (
                getattr(dtype, '__name__', str(dtype)), self._defaults.get(key, None), self._help.get(key, '')
            )
            for key, dtype in self._types.items()
        }
        lines = [
            (f'{" " * 6}{key}={default},', f'  # <{dtype}> {dhelp}')
            for key, (dtype, default, dhelp) in members.items()
        ]
        # default=0 keeps option types without any field from crashing the help output
        maxlen = max((len(head) for head, _ in lines), default=0)
        msg = '\n'.join(sorted([
            f'{head:{maxlen}s}{tail}' for head, tail in lines
        ]))

        return f'\'\n{msg}\n\''

    @_ClassProperty
    def defaults(cls):
        '''Generator of ``(field, converted default)`` pairs; ``None`` defaults stay ``None``.'''
        for key, dtype in cls._types.items():
            val = cls._defaults.get(key, None)
            yield (key, dtype(val) if val is not None else val)


# Optimizer registry: name -> (optimizer class, names of the AcqParams fields relevant to it).
# The keys are the valid choices of AcqParams.optim. All listed names are fields of AcqParams.
# NOTE(review): presumably the listed fields are forwarded as keyword arguments to the
# optimizer class -- confirm at the call site.
OPTIMIZER_SETUPS = {
    'nftgp': (SMOOptimizer, ('stabilize_interval', 'shift_mode')),
    'nft': (NFTOptimizer, ('stabilize_interval', 'shift_mode')),
    'subscore': (SUBSCOREOptimizer, (
        'gridsize',
        'corethresh',
        'corethresh_width',
        'corethresh_scale',
        'corethresh_power',
        'coremin_scale',
        'corethresh_shift',
        'single_readout_var',
        'readout_strategy',
        'corethresh_strategy',
        'coremetric',
        'coremargin',
        'coremomentum',
        'coreref',
        'shift_mode',
        'pnorm',
    )),
    'emicore': (EMICOREOptimizer, (
        'stabilize_interval',
        'gridsize',
        'pairsize',
        'samplesize',
        'core_trials',
        'corethresh',
        'corethresh_width',
        'corethresh_scale',
        'corethresh_power',
        'coremin_scale',
        'corethresh_shift',
        'corethresh_strategy',
        'coremetric',
        'coremargin',
        'coremomentum',
        'coreref',
        'single_readout_var',
        'smo_steps',
        'smo_axis',
        'pivot_steps',
        'pivot_scale',
        'pivot_mode',
        'pnorm',
    )),
    # gradient-based optimizers share the lr / momentum / gdoptim fields
    'sgd': (SGDOptimizer, ('lr', 'momentum', 'gdoptim')),
    'sgdgp': (SGDGPOptimizer, ('lr', 'momentum', 'gdoptim')),
    'gradcore': (GRADCOREOptimizer, (
        'lr',
        'momentum',
        'gdoptim',
        'corethresh',
        'corethresh_width',
        'corethresh_scale',
        'corethresh_power',
        'coremin_scale',
        'corethresh_shift',
        'single_readout_var',
        'readout_strategy',
        'corethresh_strategy',
        'coremetric',
        'coremargin',
        'coremomentum',
        'coreref',
        'pnorm',
    )),
}


# Acquisition function registry: name -> (class, extra AcqParams field names); the keys are the
# valid choices of AcqParams.func.
ACQUISITION_FNS = {
    'ei': (ExpectedImprovement, ()),
}


class QCParams(OptionParams):
    '''Options for the quantum circuit, its model couplings and the QC backend.'''
    n_layers: int = 3, 'Number of circuit layers'
    n_qbits: int = 5, 'Number of QBits'
    sector: int = 1, 'Sector -1 or 1'
    n_readout: int = 512, 'Number of shots'
    # couplings are parsed from comma-separated strings into lists of exactly 3 floats
    j_coupling: csobj(float, length=3) = '-1,0,0', 'Nearest Neigh. interaction coupling'
    h_coupling: csobj(float, length=3) = '0,0,-1', 'External magnetic field coupling'
    pbc: click.BOOL = False, 'Set Periodic/Open Boundary Conditions PBC or OBC. OBC default'
    circuit: click.Choice(['generic', 'esu2']) = 'esu2', 'Circuit name'  # noqa: F821
    backend: click.Choice(list(BACKENDS)) = 'qiskit', 'Backend for QC'
    noise_level: float = 0.0, 'Circuit noise level'
    free_angles: int = None, 'number of free angles'
    assume_exact: click.BOOL = False, 'Assume energy is exact or an estimate.'
    cache: click.Path(dir_okay=False) = None, 'Cache for ground state wave function and initial train data.'
    train_data_mode: click.Choice(('cache', 'compute')) = 'compute', 'Initial data mode'  # noqa: F821


class KernelParams(OptionParams):
    '''Kernel options; both values are validated to be strictly positive (PositiveFloat).'''
    sigma_0: PositiveFloat = 1.0, 'Prior variance'
    gamma: PositiveFloat = 2.45, 'Kernel width parameter'


class GPParams(OptionParams):
    '''Gaussian-process options: kernel choice and its parameters, observation noise, inducer.'''
    kernel: click.Choice(list(KERNELS)) = 'vqe', 'Name of the kernel'
    y_var_default: float = 1e-10, 'Default observation noise'
    y_var_default_estimates: int = None, 'Number of estimates for y_var_default'
    # nested option string parsed by KernelParams.
    # NOTE(review): if a GPParams string omits kernel_params, OptionParams.convert leaves the raw
    # default '' in the resulting tuple (defaults are not converted) -- confirm consumers cope.
    kernel_params: KernelParams() = '', 'Kernel options'
    # 'name:key=value:...' resolved against INDUCERS by AnnotatedFuncFromDict
    inducer: AnnotatedFuncFromDict(INDUCERS) = None, 'Method of inducing point selection'


class AcqParams(OptionParams):
    '''Options of the BO acquisition function (``func``) and of the optimizer (``optim``).

    ``func`` selects from ACQUISITION_FNS and ``optim`` from OPTIMIZER_SETUPS; the remaining
    fields are optimizer-specific settings.
    '''
    func: click.Choice(list(ACQUISITION_FNS)) = 'ei', 'BO acquisition function'  # noqa: F821
    optim: click.Choice(list(OPTIMIZER_SETUPS)) = 'subscore', ''
    lr: float = 1., ''
    momentum: float = 0.9, ''
    gdoptim: click.Choice(['sgd', 'adam']) = 'sgd', ''  # noqa: F821
    n_iter: int = None, ''
    max_iter: int = 200, ''
    max_eval: int = None, ''
    max_ls: int = None, ''
    gtol: float = None, ''
    # declared once: a second declaration used to override this one's default (None) with 100
    gridsize: int = 100, ''
    weighted: click.BOOL = None, ''
    stabilize_interval: int = None, 'Stabilize by measuring the center in NFT/SMO every n-th step'
    seq_reg: float = 0.0, ''
    seq_reg_init: int = -20, ''
    pairsize: int = 20, ''
    samplesize: int = 100, ''
    corethresh: float = 1.0, ''
    corethresh_width: int = 10, ''
    corethresh_scale: float = 1.0, ''
    corethresh_power: float = 1.0, ''
    coremin_scale: float = 0.0, ''
    corethresh_shift: float = 0.0, ''
    core_trials: int = 0, ''
    smo_steps: int = 0, ''
    smo_axis: click.BOOL = True, ''
    pivot_steps: int = 0, ''
    pivot_scale: float = 1.0, ''
    pivot_mode: click.Choice(['smo', 'loop']) = 'smo', ''  # noqa: F821
    mc_steps: int = 0, ''
    single_readout_var: float = None, 'Variance of a single measurement. Estimated from readout if None.'
    readout_strategy: click.Choice(['max', 'core', 'center']) = (  # noqa: F821
        'max', 'Strategy to translate core-threshold to shots'
    )
    corethresh_strategy: click.Choice([
        'last', 'lastabs', 'linreg', 'avg', 'grad', 'cheatgrad', 'grad_shift', 'cheatgrad_shift'  # noqa: F821
    ]) = (
        'last', 'Strategy to estimate core-threshold'
    )
    coremetric: click.Choice(['std', 'readout']) = (  # noqa: F821
        'std', 'Controls whether coremin_scale and corethresh are scale/std, or absolute readout'
    )
    coremargin: float = 1.0, ''
    shift_mode: click.Choice(['pi3', '2pi3', 'pi2', '5pi8', '3pi8', 'pi6']) = (  # noqa: F821, F722
        '2pi3', 'Shift in NFT; 2pi3: 2pi/3, pi3: pi/3, pi2: pi/2, pi6: pi/6'
    )
    coremomentum: int = 0, 'Moving average width for the corethresh'
    coreref: click.Choice(['cur', 'best']) = (  # noqa: F821
        'cur', 'Controls whether the current point, or the best so far should be used to update the corethresh'
    )
    pnorm: float = 2.0, ''


class HyperParams(OptionParams):
    '''Options for optimizing the kernel hyperparameters (``optim='none'`` disables it).'''
    optim: click.Choice(['adam', 'grid', 'none']) = 'none', 'Hyperparam Optimization'  # noqa: F821
    loss: click.Choice(['mll', 'loo']) = 'mll', 'Loss for hyperparameter optimization'  # noqa: F821
    lr: float = 1e-3, 'Learning rate for SGD-based hyperparameter optimization'
    threshold: float = 0.0, ''
    steps: int = 90, 'Number of steps in grid search'
    interval: str = '', ''
    max_gamma: float = 20.0, 'Maximum gamma in grid search'


class BOParams(OptionParams):
    '''Top-level Bayesian-optimization options, including nested acquisition/hyperopt options.'''
    train_samples: int = 1, 'Number of training samples'
    candidate_samples: int = 500, 'Number of acquisition function candidates'
    candidate_shots: int = 10, 'Acquisition function candidate multiplier'
    n_iter: int = 50, 'Iteration for Bayesian Optimization'
    # nested option strings parsed by AcqParams / HyperParams.
    # NOTE(review): when a BOParams string omits these, OptionParams.convert keeps the raw default
    # '' in the resulting tuple (defaults are not converted) -- confirm consumers handle that.
    acq_params: AcqParams() = '', 'Parameters for BO acquisition function'
    hyperopt: HyperParams() = '', 'Parameters for Kernel hyperparam optimization'
    iter_mode: click.Choice(['step', 'qc', 'readout']) = (  # noqa: F821
        'step',
        'Changes behavior of --n-iter: step are BO-steps, qc counts calls to QC, readout counts single QC readouts'
    )
    var_mode: click.Choice(['measure', 'estimate', 'none']) = (  # noqa: F821
        'estimate',
        'Specifies how to estimate the variance.'
    )
    cheat: click.BOOL = False, 'Enables cheating, i.e., passing the true energy func to the optimizer'
    # comma-separated subset of the choices, collected into a DebugSet
    debug: csobj(click.Choice(['all', 'grad', 'align', 'var', 'meas']), container=DebugSet) = (  # noqa: F821
        '', 'Debug enables logging of expensive statistics.'
    )
